import numpy as np
from PIL import Image


# 保存单个图像张量
def save_image_tensor(image_tensor, save_path):
    if image_tensor.dim() > 4:
        raise Exception('Image tensor too big')

    # 转换为 numpy 数组并调整维度 (C, H, W) -> (H, W, C)
    image_tensor = image_tensor.cpu()
    image_array = image_tensor.numpy().transpose(1, 2, 0)
    # 将范围从 [0, 1] 转换为 [0, 255]
    # image_array = (image_array * 255).astype(np.uint8)

    # 如果是灰度图像，移除通道维度
    if image_array.shape[2] == 1:
        image_array = image_array.squeeze(axis=2)

    # 创建 PIL 图像对象并保存
    image = Image.fromarray(image_array)
    image.save(save_path)
